{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Head Selection\n",
    "**_BERT_** is a **_Multi-layer_ _Multi-Head_** Transformer architecture. As discussed in much of the recent research, different attention heads capture different linguistic patterns. To delete words effectively using the attention mechanism, we need to choose a head that **captures patterns useful for classification.**\n",
    "\n",
    "To do this, we use a brute-force search over all the possible heads. For each head, we delete the top-K words it attends to from the sentence and measure the new classification score. In the case of sentiment, removing sentiment-related words makes the sentence neutral. The heads are ranked by how far they are able to push the sentences from the dev set toward neutral."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import logging\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,\n",
    "                              TensorDataset)\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "from tqdm import tqdm, trange\n",
    "\n",
    "from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE\n",
    "from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME\n",
    "#from pytorch_pretrained_bert.tokenization import BertTokenizer\n",
    "from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear\n",
    "\n",
    "from bertviz.bertviz import attention, visualization\n",
    "from bertviz.bertviz.pytorch_pretrained_bert import BertModel, BertTokenizer\n",
    "\n",
    "# Configure the root logger: INFO level and above, each record prefixed with a timestamp,\n",
    "# the level name and the logger name\n",
    "logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',\n",
    "                    datefmt = '%m/%d/%Y %H:%M:%S',\n",
    "                    level = logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logger for the messages emitted by this notebook\n",
    "logger = logging.getLogger(__name__)\n",
    "bert_classifier_model_dir = \"./bert_classifier/\" ## Directory passed to from_pretrained() for both the classifier and the attention model\n",
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Number of CUDA devices reported by torch (only logged in this cell)\n",
    "n_gpu = torch.cuda.device_count()\n",
    "logger.info(\"device: {}, n_gpu {}\".format(device, n_gpu))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Model for performing Classification\n",
    "model_cls = BertForSequenceClassification.from_pretrained(bert_classifier_model_dir, num_labels=2)\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
    "model_cls.to(device)\n",
    "model_cls.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Model to get the attention weights of all the heads\n",
    "model = BertModel.from_pretrained(bert_classifier_model_dir)\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
    "model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_seq_len=70 # Maximum sequence length: every tokenized sentence is zero-padded to this many ids\n",
    "sm = torch.nn.Softmax(dim=-1) ## Softmax over the last dimension of the classifier output (presumably the class scores), i.e. per example, not over the batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_multiple_examples(input_sentences, bs=32):\n",
    "    \"\"\"\n",
    "    This fucntion returns classification predictions for batch of sentences.\n",
    "    input_sentences: list of strings\n",
    "    bs : batch_size : int\n",
    "    \"\"\"\n",
    "    \n",
    "    ## Prepare data for classification\n",
    "    ids = []\n",
    "    segment_ids = []\n",
    "    input_masks = []\n",
    "    pred_lt = []\n",
    "    for sen in input_sentences:\n",
    "        text_tokens = tokenizer.tokenize(sen)\n",
    "        tokens = [\"[CLS]\"] + text_tokens + [\"[SEP]\"]\n",
    "        temp_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "        input_mask = [1] * len(temp_ids)\n",
    "        segment_id = [0] * len(temp_ids)\n",
    "        padding = [0] * (max_seq_len - len(temp_ids))\n",
    "\n",
    "        temp_ids += padding\n",
    "        input_mask += padding\n",
    "        segment_id += padding\n",
    "        \n",
    "        ids.append(temp_ids)\n",
    "        input_masks.append(input_mask)\n",
    "        segment_ids.append(segment_id)\n",
    "    \n",
    "    ## Convert input lists to Torch Tensors\n",
    "    ids = torch.tensor(ids)\n",
    "    segment_ids = torch.tensor(segment_ids)\n",
    "    input_masks = torch.tensor(input_masks)\n",
    "    \n",
    "    steps = len(ids) // bs\n",
    "    \n",
    "    for i in range(steps+1):\n",
    "        if i == steps:\n",
    "            temp_ids = ids[i * bs : len(ids)]\n",
    "            temp_segment_ids = segment_ids[i * bs: len(ids)]\n",
    "            temp_input_masks = input_masks[i * bs: len(ids)]\n",
    "        else:\n",
    "            temp_ids = ids[i * bs : i * bs + bs]\n",
    "            temp_segment_ids = segment_ids[i * bs: i * bs + bs]\n",
    "            temp_input_masks = input_masks[i * bs: i * bs + bs]\n",
    "        \n",
    "        temp_ids = temp_ids.to(device)\n",
    "        temp_segment_ids = temp_segment_ids.to(device)\n",
    "        temp_input_masks = temp_input_masks.to(device)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            preds = sm(model_cls(temp_ids, temp_segment_ids, temp_input_masks))\n",
    "        pred_lt.extend(preds.tolist())\n",
    "    \n",
    "    return pred_lt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_file(path,size):\n",
    "    with open(path) as fp:\n",
    "        data = fp.read().splitlines()[:size]\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_attention_for_batch(input_sentences, bs=32):\n",
    "    \"\"\"\n",
    "    This function calculates attention weights of all the heads and\n",
    "    returns it along with the encoded sentence for further processing.\n",
    "    \n",
    "    input sentence: list of strings\n",
    "    bs : batch_size\n",
    "    \"\"\"\n",
    "    \n",
    "    ## Preprocessing for BERT \n",
    "    ids = []\n",
    "    segment_ids = []\n",
    "    input_masks = []\n",
    "    pred_lt = []\n",
    "    ids_for_decoding = []\n",
    "    for sen in input_sentences:\n",
    "        text_tokens = tokenizer.tokenize(sen)\n",
    "        tokens = [\"[CLS]\"] + text_tokens + [\"[SEP]\"]\n",
    "        temp_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "        input_mask = [1] * len(temp_ids)\n",
    "        segment_id = [0] * len(temp_ids)\n",
    "        padding = [0] * (max_seq_len - len(temp_ids))\n",
    "        \n",
    "        ids_for_decoding.append(tokenizer.convert_tokens_to_ids(tokens))\n",
    "        temp_ids += padding\n",
    "        input_mask += padding\n",
    "        segment_id += padding\n",
    "        \n",
    "        ids.append(temp_ids)\n",
    "        input_masks.append(input_mask)\n",
    "        segment_ids.append(segment_id)\n",
    "    ## Convert the list of int ids to Torch Tensors\n",
    "    ids = torch.tensor(ids)\n",
    "    segment_ids = torch.tensor(segment_ids)\n",
    "    input_masks = torch.tensor(input_masks)\n",
    "    \n",
    "    steps = len(ids) // bs\n",
    "    \n",
    "    for i in trange(steps+1):\n",
    "        if i == steps:\n",
    "            temp_ids = ids[i * bs : len(ids)]\n",
    "            temp_segment_ids = segment_ids[i * bs: len(ids)]\n",
    "            temp_input_masks = input_masks[i * bs: len(ids)]\n",
    "        else:\n",
    "            temp_ids = ids[i * bs : i * bs + bs]\n",
    "            temp_segment_ids = segment_ids[i * bs: i * bs + bs]\n",
    "            temp_input_masks = input_masks[i * bs: i * bs + bs]\n",
    "        \n",
    "        temp_ids = temp_ids.to(device)\n",
    "        temp_segment_ids = temp_segment_ids.to(device)\n",
    "        temp_input_masks = temp_input_masks.to(device)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            _, _, attn = model(temp_ids, temp_segment_ids, temp_input_masks)\n",
    "        \n",
    "        # Add all the Attention Weights to CPU memory\n",
    "        # Attention weights for each layer is stored in a dict 'attn_prob'\n",
    "        for k in range(12):\n",
    "            attn[k]['attn_probs'] = attn[k]['attn_probs'].to('cpu')\n",
    "        \n",
    "        '''\n",
    "        attention weights are stored in this way:\n",
    "        att_lt[layer]['attn_probs']['input_sentence']['head']['length_of_sentence']\n",
    "        '''\n",
    "        # Concate Attention weights for all the examples in the list att_lt[layer_no]['attn_probs']\n",
    "        \n",
    "        if i == 0:\n",
    "            att_lt = attn\n",
    "            heads = len(att_lt)\n",
    "        else:\n",
    "            for j in range(heads):\n",
    "                att_lt[j]['attn_probs'] = torch.cat((att_lt[j]['attn_probs'],attn[j]['attn_probs']),0)\n",
    "        \n",
    "    \n",
    "    return att_lt, ids_for_decoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_sentences(input_sentences, att, decoding_ids, threshold=0.25):\n",
    "    \"\"\"\n",
    "    This function processes each input sentence by removing the top tokens defined threshold value.\n",
    "    Each sentence is processed for each head.\n",
    "    \n",
    "    input_ids: list of strings\n",
    "    decoding_ids: indexed input_sentnces thus len(input_sentences) == len(decoding_ids)\n",
    "    threshold: Percentage of the top indexes to be removed\n",
    "    \"\"\"\n",
    "    # List of None of num_of_layers * num_of_heads to save the results of each head for input_sentences\n",
    "    \n",
    "    lt = [None for x in range(len(att) * len(att[0]['attn_probs'][0]))]\n",
    "    #print(len(lt))\n",
    "    \n",
    "    inx = 0\n",
    "    for i in trange(len(att)): #  For all the layers\n",
    "        for j in range(len(att[i]['attn_probs'][0])): # For all the heads in the ith Layer\n",
    "            processed_sen = [None for q in decoding_ids] # List of len(decoding_ids)\n",
    "            for k in range(len(input_sentences)): # For all the senteces \n",
    "                _, topi = att[i]['attn_probs'][k][j][0].topk(len(decoding_ids[k])) # Get top attended ids\n",
    "                topi = topi.tolist()\n",
    "                topi = topi[:int(len(topi) * threshold)] \n",
    "                ## Decode the sentece after removing the topk indexes\n",
    "                final_indexes = []\n",
    "                count = 0\n",
    "                count1 = 0\n",
    "                tokens = [\"[CLS]\"] + tokenizer.tokenize(input_sentences[k]) + [\"[SEP]\"]\n",
    "                while count < len(decoding_ids[k]):\n",
    "                    if count in topi: # Remove index if present in topk\n",
    "                        while (count + count1 + 1) < len(decoding_ids[k]):\n",
    "                            if \"##\" in tokens[count + count1 + 1]:\n",
    "                                count1 += 1\n",
    "                            else:\n",
    "                                break\n",
    "                        count += count1\n",
    "                        count1 = 0\n",
    "                    else: # Else add to the decoded sentence\n",
    "                        final_indexes.append(decoding_ids[k][count])\n",
    "                    count += 1\n",
    "                tmp = tokenizer.convert_ids_to_tokens(final_indexes) # Convert ids to token\n",
    "                # Convert toknes to sentence\n",
    "                processed_sen[k] = \" \".join(tmp).replace(\" ##\", \"\").replace(\"[CLS]\",\"\").replace(\"[SEP]\",\"\").strip()\n",
    "            lt[inx] = processed_sen # Store sentences for inxth head\n",
    "            inx += 1\n",
    "    \n",
    "    return lt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_block_head(processed_sentence_list, lmbd = 0.1, num_heads = 12):\n",
    "    \"\"\"\n",
    "    This function calculates classification scores for the sentences generated by each head\n",
    "    and sorts the heads from best to worst.\n",
    "\n",
    "    score = (min(pred) + lmbd) / (max(pred) + lmbd), averaged over the sentences of a head;\n",
    "    lmbd is a smoothing param and pred is the list of probability scores for each class.\n",
    "    For the best case pred = [0.5, 0.5] ==> score = 1\n",
    "\n",
    "    processed_sentence_list: one list of sentences per head, ordered layer by layer\n",
    "                             (all heads of layer 0 first, then layer 1, ...)\n",
    "    lmbd: smoothing parameter of the score\n",
    "    num_heads: number of attention heads per layer, used to turn the flat index of\n",
    "               processed_sentence_list into (layer, head)\n",
    "\n",
    "    it returns sorted list of (Layer, Head, Score)\n",
    "    \"\"\"\n",
    "    scores = {}\n",
    "    for i in trange(len(processed_sentence_list)): # sentences by each head\n",
    "        pred = np.array(run_multiple_examples(processed_sentence_list[i]))\n",
    "        scores[i] = np.mean([(min(x) + lmbd) / (max(x) + lmbd) for x in pred])\n",
    "    # Highest score first: a score close to 1 means the class probabilities are close to each other\n",
    "    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)\n",
    "    # divmod maps the flat head index back to (layer, head)\n",
    "    score_lt = [divmod(idx, num_heads) + (score,) for idx, score in ranked]\n",
    "    return score_lt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pos_examples_file = \"/home/ubuntu/bhargav/data/yelp/sentiment_dev_1.txt\"\n",
    "neg_examples_file = \"/home/ubuntu/bhargav/data/yelp/sentiment_dev_0.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 examples from each class worked well. The bottleneck is the run_multiple_examples() function;\n",
    "# with more memory (either CPU or GPU) one can reduce the processing time by increasing batch_size.\n",
    "# With a batch_size of 32 it takes around 24 mins for 100 examples on CPU.\n",
    "num_examples_per_class = 100\n",
    "pos_data = read_file(pos_examples_file, num_examples_per_class)\n",
    "neg_data = read_file(neg_examples_file, num_examples_per_class)\n",
    "# Positive examples first, followed by the negative examples\n",
    "data = pos_data + neg_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: number of positive, negative and combined examples that were read\n",
    "print(len(pos_data), len(neg_data), len(data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention weights of all layers/heads for every example, plus the unpadded token ids of each sentence\n",
    "att, decoding_ids = get_attention_for_batch(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One list of sentences per head, each with that head's top attended tokens removed (default threshold=0.25)\n",
    "sen_list = process_sentences(data, att, decoding_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the processed sentences of every head with the classifier and sort the heads by that score\n",
    "scores = get_block_head(sen_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the sorted (layer, head, score) tuples, highest score first\n",
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
